"""
# -*- coding: utf-8 -*-
# @Time    : 2023/6/19 10:29
# @Author  : 王摇摆
# @FileName: ResNet50CSDN.py
# @Software: PyCharm
# @Blog    ：https://blog.csdn.net/weixin_44943389?type=blog
"""

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, utils
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data.dataset import Dataset
from torchvision.transforms import transforms
from pathlib import Path
import cv2
from PIL import Image
import torch.nn.functional as F

# Preprocessing applied to every CIFAR-10 image: convert to a tensor,
# normalize each RGB channel with mean 0.5 / std 0.5, then resize to the
# 224x224 resolution that is fed to the network.
transform = transforms.Compose([
    ToTensor(),
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
    ),
    transforms.Resize((224, 224)),
])

# Training split (downloaded into ./data on first use).
training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transform,
)

# Held-out test split, preprocessed with the same transform.
testing_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=transform,
)
print('1. 数据集加载完成，预处理完毕')

batch_size = 64  # hyperparameter: number of samples per mini-batch

# Training batches are shuffled, and the trailing partial batch is dropped so
# that every optimisation step sees exactly `batch_size` samples.
train_data = DataLoader(dataset=training_data, batch_size=batch_size, shuffle=True, drop_last=True)

# Evaluation must visit every test sample exactly once. Dropping the last
# partial batch would silently skip samples while metrics are still reported
# against the full test set, and shuffling has no effect on the metrics, so
# both are disabled here.
test_data = DataLoader(dataset=testing_data, batch_size=batch_size, shuffle=False, drop_last=False)
print('2. 数据集已划分完成')


# Residual (bottleneck) block definition
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The main path reduces to ``out_channels`` with a 1x1 convolution, applies a
    3x3 convolution, and expands to ``out_channels * 4`` with a final 1x1
    convolution. Every convolution is followed by batch normalization and is
    therefore created without a bias. The block output is
    ``relu(main_path(x) + shortcut(x))``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Width of the two inner convolutions; the block emits
            ``out_channels * 4`` channels.
        stride: Strides of the three convolutions, in order. Spatial
            down-sampling is expected to happen in the 3x3 convolution
            (``stride[1]``), which is the stride the projection shortcut uses.
        padding: Paddings of the three convolutions, in order.
        first: Force a projection shortcut (1x1 convolution + batch norm).
            A projection is also created automatically whenever the identity
            shortcut could not be added to the main path, i.e. when
            ``stride[1] != 1`` or ``in_channels != out_channels * 4``.
    """

    def __init__(self, in_channels, out_channels, stride=(1, 1, 1), padding=(0, 1, 0), first=False) -> None:
        super(Bottleneck, self).__init__()
        expanded_channels = out_channels * 4

        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride[0], padding=padding[0], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),  # in-place to save memory
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride[1], padding=padding[1], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),  # in-place to save memory
            nn.Conv2d(out_channels, expanded_channels, kernel_size=1, stride=stride[2], padding=padding[2],
                      bias=False),
            nn.BatchNorm2d(expanded_channels),
        )

        # Shortcut branch. An empty Sequential acts as the identity; a 1x1
        # projection is needed whenever the input and the main-path output
        # differ in channel count or spatial size, otherwise the residual
        # addition in forward() would fail with a shape mismatch.
        needs_projection = first or stride[1] != 1 or in_channels != expanded_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, expanded_channels, kernel_size=1, stride=stride[1], bias=False),
                nn.BatchNorm2d(expanded_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bottleneck(x) + shortcut(x))``."""
        out = self.bottleneck(x)
        out += self.shortcut(x)
        out = F.relu(out)
        return out


# In networks that use batch normalization, the convolution layers are created without a bias term
class ResNet50(nn.Module):
    """ResNet-50 style image classifier assembled from bottleneck blocks.

    Layout:
        * ``conv1``: stem made of a 7x7 stride-2 convolution, batch norm, ReLU
          and a 3x3 stride-2 max pool (no residual connection).
        * ``conv2`` .. ``conv5``: stages of 3, 4, 6 and 3 blocks with base
          widths 64, 128, 256 and 512; each block emits four times its base
          width. Stages ``conv3`` to ``conv5`` down-sample in their first block.
        * ``avgpool`` + ``fc``: global average pooling followed by a linear
          classifier.

    Args:
        Bottleneck: Block class used to build the stages. It is called as
            ``block(in_channels, out_channels, stride, padding, first=...)``
            and must output ``out_channels * 4`` channels.
        num_classes: Number of output logits.
    """

    def __init__(self, Bottleneck, num_classes=10) -> None:
        super(ResNet50, self).__init__()
        # Channel count of the tensor entering the next block; updated by
        # _make_layer as stages are appended.
        self.in_channels = 64

        # Stem. The ReLU sits between batch norm and pooling as in the
        # standard ResNet stem. It has no parameters, and the convolution and
        # batch norm keep indices 0 and 1, so parameter keys are unchanged.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # conv2: 3 blocks, no down-sampling (the stem already pooled).
        self.conv2 = self._make_layer(Bottleneck, 64, [[1, 1, 1]] * 3, [[0, 1, 0]] * 3)

        # conv3: 4 blocks, the first one halves the resolution.
        self.conv3 = self._make_layer(Bottleneck, 128, [[1, 2, 1]] + [[1, 1, 1]] * 3, [[0, 1, 0]] * 4)

        # conv4: 6 blocks, the first one halves the resolution.
        self.conv4 = self._make_layer(Bottleneck, 256, [[1, 2, 1]] + [[1, 1, 1]] * 5, [[0, 1, 0]] * 6)

        # conv5: 3 blocks, the first one halves the resolution.
        self.conv5 = self._make_layer(Bottleneck, 512, [[1, 2, 1]] + [[1, 1, 1]] * 2, [[0, 1, 0]] * 3)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.in_channels now holds the width of the last stage (512 * 4),
        # so the classifier stays in sync with the stage definitions above.
        self.fc = nn.Linear(self.in_channels, num_classes)

    def _make_layer(self, block, out_channels, strides, paddings):
        """Build one stage as a Sequential of ``len(strides)`` blocks.

        Args:
            block: Block class to instantiate.
            out_channels: Base width of every block in the stage.
            strides: One ``[s0, s1, s2]`` stride triple per block.
            paddings: One ``[p0, p1, p2]`` padding triple per block.

        Returns:
            ``nn.Sequential`` holding the blocks. Only the first block is
            created with ``first=True``.

        Raises:
            ValueError: If ``strides`` and ``paddings`` differ in length.
        """
        if len(strides) != len(paddings):
            raise ValueError("strides and paddings must have the same length")

        layers = []
        for index, (stride, padding) in enumerate(zip(strides, paddings)):
            layers.append(block(self.in_channels, out_channels, stride, padding, first=(index == 0)))
            # Every block after the first receives the expanded width.
            self.in_channels = out_channels * 4

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of 3-channel images to ``(batch, num_classes)`` logits."""
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)

        out = self.avgpool(out)
        # Collapse (batch, C, 1, 1) to (batch, C) for the linear layer.
        out = out.flatten(1)
        out = self.fc(out)
        return out


# Build the network.
res50 = ResNet50(Bottleneck)
print('3. ResNet50模型已创建完毕')

# Preview one training batch: print its shape, tile the images into a grid,
# undo the mean/std normalization for display, and print the labels.
images, labels = next(iter(train_data))
print(images.shape)
img = utils.make_grid(images)
img = img.numpy().transpose(1, 2, 0)  # CHW -> HWC, the layout imshow expects
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
img = img * std + mean
# Iterate over the labels that are actually in the batch instead of assuming
# a hard-coded batch size of 64.
print(list(labels))
# NOTE(review): plt.show() is never called, so the figure is presumably only
# visible in interactive backends - confirm whether that is intended.
plt.imshow(img)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU可用")
else:
    device = torch.device("cpu")
    print("GPU不可用")

model = res50.to(device)
cost = torch.nn.CrossEntropyLoss()
print('4. 交叉熵损失函数已定义完成')
optimizer = torch.optim.Adam(model.parameters())
print('5. Adam优化器已定义完成')


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Train for one pass over ``loader``; return (mean sample loss, accuracy in %)."""
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        optimizer.zero_grad()
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean, so weight it by the batch
        # size to obtain a true per-sample average at the end.
        n = targets.size(0)
        loss_sum += loss.item() * n
        correct += (outputs.argmax(dim=1) == targets).sum().item()
        seen += n
    # Divide by the samples actually processed (the loader may drop a partial
    # batch); max() guards against an empty loader.
    seen = max(seen, 1)
    return loss_sum / seen, 100.0 * correct / seen


def _evaluate(model, loader, criterion, device):
    """Evaluate on ``loader`` without gradients; return (mean sample loss, accuracy in %)."""
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    # No autograd graph is needed for inference; this saves memory and time.
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            n = targets.size(0)
            loss_sum += criterion(outputs, targets).item() * n
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            seen += n
    seen = max(seen, 1)
    return loss_sum / seen, 100.0 * correct / seen


# Number of batches per epoch in each loader.
print(len(train_data))
print(len(test_data))
epochs = 10
print('================模型开始训练======================')
for epoch in range(epochs):  # one iteration per epoch
    print("Epoch {}/{}".format(epoch + 1, epochs))
    print("-" * 10)

    # Learn on the training set, then measure on the test set.
    train_loss, train_accuracy = _train_one_epoch(model, train_data, cost, optimizer, device)
    test_loss, test_accuracy = _evaluate(model, test_data, cost, device)

    print("Train Loss is:{:.4f}, Train Accuracy is:{:.4f}%, Test Loss is::{:.4f} Test Accuracy is:{:.4f}%".format(
        train_loss, train_accuracy,
        test_loss,
        test_accuracy
    ))
